import warnings

from torch import nn

from .SFPOOL import SoftPooling2D


def _softpool_from_maxpool(pool):
    """Build a ``SoftPooling2D`` with the same geometry as the ``nn.MaxPool2d`` ``pool``."""
    # The replacement only receives kernel_size / stride / padding / ceil_mode,
    # so a non-default dilation or return_indices=True on the original layer is
    # NOT carried over. Warn rather than changing the network silently.
    dilation = pool.dilation
    dilations = (dilation,) if isinstance(dilation, int) else tuple(dilation)
    if any(d != 1 for d in dilations):
        warnings.warn(
            f"{pool!r} uses dilation={dilation}; this setting is not passed to "
            "SoftPooling2D and will be dropped.",
            stacklevel=2,
        )
    if pool.return_indices:
        warnings.warn(
            f"{pool!r} has return_indices=True; this setting is not passed to "
            "SoftPooling2D and will be dropped.",
            stacklevel=2,
        )
    return SoftPooling2D(kernel_size=pool.kernel_size, stride=pool.stride, padding=pool.padding,
                         ceil_mode=pool.ceil_mode)


def convert_maxpool2d_to_softpool2d(model):
    '''
    Replace every ``nn.MaxPool2d`` inside ``model`` with a ``SoftPooling2D``.

    The module tree is walked recursively. Each ``nn.MaxPool2d`` is swapped in
    place for a ``SoftPooling2D`` built with the same ``kernel_size``,
    ``stride``, ``padding`` and ``ceil_mode``. A warning is emitted for layers
    whose ``dilation`` / ``return_indices`` settings cannot be carried over.

    :param model: the ``nn.Module`` whose max-pooling layers are replaced;
        it is modified in place.
    :return: ``model`` itself after the in-place replacement. If ``model`` is
        itself an ``nn.MaxPool2d`` (which cannot be replaced in place), the new
        ``SoftPooling2D`` is returned instead.
    '''
    if isinstance(model, nn.MaxPool2d):
        return _softpool_from_maxpool(model)

    for child_name, child in model.named_children():
        new_child = convert_maxpool2d_to_softpool2d(child)
        # Non-pooling children are converted in place and come back unchanged;
        # only re-register the slot when the child was actually swapped.
        if new_child is not child:
            setattr(model, child_name, new_child)
    return model

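'''
ceil_mode chooses between ceil and floor (as in math.ceil / math.floor) when a
pooling layer computes its output size: ceil_mode=True rounds the output size up,
ceil_mode=False (the default) rounds it down.
'''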